# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CaMeL agent implementation."""

import asyncio
from collections.abc import Iterator
import queue
import re
import threading
from typing import Any, AsyncGenerator, Callable, Optional

from google.adk import runners
from google.adk.agents import base_agent
from google.adk.agents import invocation_context
from google.adk.agents import llm_agent
from google.adk.agents import loop_agent
from google.adk.events import event
from google.adk.events import event_actions
from google.adk.models import base_llm
from google.genai import types
import pydantic
from pydantic.v1 import validators
from typing_extensions import override

from ..camel_library import function_types
from ..camel_library import result
from ..camel_library import security_policy
from ..camel_library.capabilities import capabilities
from ..camel_library.interpreter import camel_value
from ..camel_library.interpreter import interpreter
from ..camel_library.interpreter import library
from . import prompts
from . import utils

# Short module-level aliases for the pydantic, ADK and CaMeL-library names used
# by the classes in this file.
BaseModel = pydantic.BaseModel
InvocationContext = invocation_context.InvocationContext

# ADK event types emitted by the agents below.
Event = event.Event
EventActions = event_actions.EventActions

# ADK agent base classes.
LoopAgent = loop_agent.LoopAgent
LlmAgent = llm_agent.LlmAgent
BaseAgent = base_agent.BaseAgent

BaseLlm = base_llm.BaseLlm

# CaMeL interpreter configuration and value types.
DependenciesPropagationMode = interpreter.DependenciesPropagationMode

FunctionCall = function_types.FunctionCall
CaMeLFunction = camel_value.CaMeLFunction
Namespace = camel_value.Namespace
CaMeLValue = camel_value.Value

# Security policy types.
Allowed = security_policy.Allowed
Denied = security_policy.Denied
SecurityPolicyEngine = security_policy.SecurityPolicyEngine
SecurityPolicyResult = security_policy.SecurityPolicyResult

CaMeLException = interpreter.CaMeLException

# A tool is a (py_callable, capabilities, dependencies) triple; the last two
# elements are passed through to `CaMeLFunction` unchanged.
Tool = tuple[Callable[..., Any], Any, Any]

# pydantic v1 validators, used to convert the Q-LLM's text response into the
# requested primitive type.
int_validator = validators.int_validator
float_validator = validators.float_validator
bool_validator = validators.bool_validator


class QuarantinedLlmService(BaseModel):
  """Manages synchronous interactions with the Quarantined LLM (Q-LLM).

  The service wraps an `LlmAgent` driven by an `InMemoryRunner`. Every query
  runs in its own temporary session, and the async runner is bridged to a
  synchronous iterator so that it can be called from the (synchronous) CaMeL
  interpreter.

  Attributes:
    model: The model (name or instance) backing the Q-LLM agent.
    name: Name used for the agent, the runner's app and event-author matching.
    user_id: User id under which the temporary sessions are created.
    agent: The Q-LLM agent.
    runner: The in-memory runner that executes `agent`.
    pattern: Compiled pattern that matches exactly `name`; used to select the
      events authored by the Q-LLM agent.
  """

  model: str | BaseLlm
  name: str
  user_id: str

  agent: LlmAgent
  runner: runners.InMemoryRunner
  pattern: re.Pattern

  model_config = {"arbitrary_types_allowed": True}

  def __init__(
      self,
      model: str | BaseLlm,
      name: str = "QLLM_Service",
      user_id: str = "test_user_id",
  ):
    """Builds the Q-LLM agent, its runner and the author-matching pattern.

    Args:
      model: The model (name or instance) used by the Q-LLM agent.
      name: Name of the agent; also used as the runner's app name.
      user_id: User id under which sessions are created.
    """
    agent = LlmAgent(
        model=model,
        name=name,
        instruction=prompts.QLLM_SYSTEM_PROMPT,
    )

    runner = runners.InMemoryRunner(
        agent=agent,
        app_name=name,
    )

    # The name is escaped, so the pattern matches the literal name only.
    pattern = re.compile(re.escape(name))

    super().__init__(
        model=model,
        name=name,
        user_id=user_id,
        agent=agent,
        runner=runner,
        pattern=pattern,
    )

  async def _run_async(
      self, query: str, output_schema: str
  ) -> AsyncGenerator[Event, None]:
    """Runs a query on a fresh, temporary Q-LLM session.

    Args:
      query: The query to send to the Q-LLM.
      output_schema: The requested output type, appended to the query text.

    Yields:
      The events produced by the runner for this query.
    """

    qllm_session = await self.runner.session_service.create_session(
        app_name=self.name, user_id=self.user_id
    )

    try:
      qllm_query = f"{query} \n\n output_schema: {output_schema}"
      content = types.Content(role="user", parts=[types.Part(text=qllm_query)])

      async for e in self.runner.run_async(
          user_id=qllm_session.user_id,  # Session object contains user_id
          session_id=qllm_session.id,
          new_message=content,
      ):
        yield e
    finally:
      # Always drop the temporary session, including when the run fails or the
      # consumer stops iterating early; otherwise sessions accumulate in the
      # session service.
      await self.runner.session_service.delete_session(
          app_name=self.name, user_id=self.user_id, session_id=qllm_session.id
      )

  def run(self, query: str, output_schema: str) -> Iterator[Event]:
    """Runs the QLLM agent synchronously in a new thread.

    NOTE: This sync interface is solely because the CaMeL interpreter is
    synchronous and does not support the `await` keyword. The
    `query_ai_assistant` function is a wrapper around this synchronous run
    method.

    NOTE: This method is similar to the `run` method in the `runners.Runner`
    class.

    TODO: Later versions of CaMeL should support asynchronous execution and
    this method will be removed.

    Args:
      query: The query to run.
      output_schema: The output schema of the query.

    Yields:
      The events generated by the QLLM.

    Raises:
      Exception: Any exception raised while running the Q-LLM in the background
        thread is re-raised here once all queued events have been yielded.
    """
    end_of_stream = object()  # Private sentinel marking the end of the events.
    event_queue: queue.Queue = queue.Queue()
    failures: list[Exception] = []

    async def _invoke_run_async() -> None:
      async for e in self._run_async(query, output_schema):
        event_queue.put(e)

    def _asyncio_thread_main() -> None:
      try:
        asyncio.run(_invoke_run_async())
      except Exception as exc:  # pylint: disable=broad-exception-caught
        # Hand the failure over to the consuming thread instead of losing it
        # in the background thread.
        failures.append(exc)
      finally:
        # Exactly one sentinel, so the consumer loop below always terminates.
        event_queue.put(end_of_stream)

    thread = threading.Thread(target=_asyncio_thread_main)
    thread.start()

    # Consume and re-yield the events from the background thread.
    while (item := event_queue.get()) is not end_of_stream:
      yield item

    thread.join()

    if failures:
      raise failures[0]

  def get_query_ai_assistant_function(
      self,
  ) -> Callable[[str, str], str | int | float | bool]:
    """Returns the standalone `query_ai_assistant` built-in bound to this service.

    The `query_ai_assistant` function is a wrapper around the `run` method of
    the `QuarantinedLlmService` class. `query_ai_assistant` needs the `self`
    object but it can't be passed as a parameter because it needs to be added to
    the namespace of the CaMeL interpreter as a standalone built-in function.

    Returns:
      A closure taking `(query, output_schema)` and returning the Q-LLM's
      response converted to the requested primitive type.
    """

    # NOTE(review): the docstring of `query_ai_assistant` is presumably rendered
    # into the P-LLM system prompt via `prompts.make_function`; treat it as
    # prompt text and confirm before rewording it.
    def query_ai_assistant(
        query: str, output_schema: str
    ) -> str | int | float | bool:
      """Queries a Large Language Model with `query` and returns the language model's output.

      It must be used to process unstructured data into structured one.

      It is absolutely imperative that you use this function to parse data whose
      structure you don't know instead of parsing using regular expressions
      and/or
      string manipulation.

      Specify the expected output format in the query itself. Output_format must
      be a string, as in ensure that the type is quoted.

      If not provided with enough information, the model will raise a
      `NotEnoughInformationError`. This means that you might
      need to provide the assistant with more information such as email
      addresses, dates, identifiers, names, etc. **DO NOT HESITATE TO PASS THE
      ENTIRE
      OBJECTS RETURN FROM THE TOOL FUNCTIONS IF THE MODEL DOES NOT HAVE ENOUGH
      INFORMATION**.
      Args:
        query: a string with the query. Make sure to provide sufficient
          instructions to the AI assistant so that it can understand what it
          needs to do. Avoid just passing it tool outputs without additional
          instructions. Provide as much details as possible. Keep in mind that
          the assistant does not have direct access to the variables, so you
          need to insert **all the relevant information in the prompt**. When in
          doubt, do not hesitate to just provide full tool outputs, as long as
          they are provided with instructions on what to do. It is highly
          preferable to provide more information than necessary, rather than
          less information.
        output_schema: a string represeting the type that specifies the expected
          output format from the model. The fields should have types as specific
          as possible to make sure the parsing is correct and accurate. Allowed
          types are: 'int' , 'str' ,'float' , 'bool'

      Example:
        number = 42
          the parsed output of the model. Return
          just the value.

      Returns:
        The parsed output of the model.
      """

      # Reject unsupported schemas before spending a Q-LLM call on them.
      if output_schema not in ["int", "str", "float", "bool"]:
        raise ValueError(f"Unsupported output schema: `{output_schema}`")

      response_parts = []

      for e in self.run(
          query=query,
          output_schema=output_schema,
      ):
        # Only keep content authored by the Q-LLM agent itself.
        if e.content and self.pattern.fullmatch(e.author):
          response_parts.extend(e.content.parts)

      response_text = "".join(map(utils.sanitized_part, response_parts))

      print(
          f"query_ai_assistant(query='{query}',"
          f" output_schema='{output_schema}') -> {response_text}",
          end="\n\n",
      )

      if output_schema == "int":
        return int_validator(response_text)
      elif output_schema == "str":
        return str(response_text)
      elif output_schema == "float":
        return float_validator(response_text)
      elif output_schema == "bool":
        return bool_validator(response_text)
      else:
        # Defensive only: unsupported schemas are rejected above.
        raise ValueError(f"Unsupported output schema: `{output_schema}`")

    return query_ai_assistant


class CaMelInterpreterService(BaseModel):
  """Manages CaMeL interpreter state, functions, and execution."""

  model: str | BaseLlm
  tools: list[Tool]
  classes_to_exclude: frozenset[str]
  eval_args: interpreter.EvalArgs
  namespace: Namespace
  quarantined_llm_service: QuarantinedLlmService

  model_config = {"arbitrary_types_allowed": True}

  def __init__(
      self,
      model: str | BaseLlm,
      tools: list[Tool],
      eval_args: interpreter.EvalArgs,
  ):
    quarantined_llm_service = QuarantinedLlmService(
        model=model,
        name="QLLM_Service",
    )  # Manages interactions with the QLLM.

    classes_to_exclude: frozenset[str] = frozenset(
        {"datetime", "timedelta", "date", "time", "NaiveDatetime", "timezone"}
    )

    camel_tools = tools[:]

    camel_tools.append((
        quarantined_llm_service.get_query_ai_assistant_function(),
        capabilities.Capabilities.camel(),
        (capabilities.readers.Public(),),
    ))

    namespace = library.make_builtins_namespace(
        variables={
            (func_name := f.__name__): CaMeLFunction(
                name=func_name,
                py_callable=f,
                capabilities=caps,
                dependencies=deps,
            )
            for f, caps, deps in camel_tools
            if hasattr(f, "__name__")
        }
    )

    super().__init__(
        model=model,
        tools=camel_tools,
        classes_to_exclude=classes_to_exclude,
        eval_args=eval_args,
        namespace=namespace,
        quarantined_llm_service=quarantined_llm_service,
    )

  def get_funcs_for_pllm_prompt(self) -> list[Callable[..., Any]]:
    return [f for f, _, _ in self.tools if hasattr(f, "__name__")]

  def get_classes_to_exclude(self) -> frozenset[str]:
    return self.classes_to_exclude

  def execute_code(
      self,
      code: str,
      tool_calls_chain: list[function_types.FunctionCall],
      current_dependencies: tuple[Any, ...],
      verbose: bool = False,
  ) -> tuple[
      str,
      list[function_types.FunctionCall],
      CaMeLException | None,
      camel_value.Namespace,
      tuple[Any, ...],
  ]:
    """Interprets the CaMeL code using the internal namespace."""
    if verbose:
      print(code)

    # The namespace passed here is self.namespace, which is managed internally
    interpreter_res, updated_namespace, new_tool_calls, new_dependencies = (
        interpreter.parse_and_interpret_code(
            code,
            self.namespace,
            tool_calls_chain,
            current_dependencies,
            self.eval_args,
        )
    )
    self.namespace = updated_namespace  # Update internal namespace state

    printed_output = utils.extract_print_output(new_tool_calls)
    ad_tool_calls = new_tool_calls

    final_eval_output_str = ""
    error_obj = None
    match interpreter_res:
      case result.Error(error):
        error_obj = error
      case result.Ok(v_obj):
        final_eval_output_str = v_obj.raw if v_obj.raw is not None else ""

    combined_output = f"{printed_output}\n{final_eval_output_str}".strip()
    return (
        combined_output,
        ad_tool_calls,
        error_obj,
        updated_namespace,
        new_dependencies,
    )


class CaMeLInterpreter(BaseAgent):
  """Agent that runs the P-LLM generated code through the CaMeL interpreter.

  The agent communicates through the session state:
    * reads `p_llm_code`, `function_calls` and `dependencies`;
    * writes `function_calls`, `dependencies` and `eval_result`;
    * resets `p_llm_code` to None after a successful run.

  Attributes:
    camel_interpreter_service: The service that executes the code.
  """

  camel_interpreter_service: CaMelInterpreterService

  model_config = {"arbitrary_types_allowed": True}

  def __init__(
      self,
      name: str,
      camel_interpreter_service: CaMelInterpreterService,
  ):
    """Initializes the agent.

    Args:
      name: The agent name, also used as author and role of emitted events.
      camel_interpreter_service: The service that executes the code.
    """
    super().__init__(
        name=name, camel_interpreter_service=camel_interpreter_service
    )

  @override
  async def _run_async_impl(
      self, ctx: InvocationContext
  ) -> AsyncGenerator[Event, None]:
    """Executes `p_llm_code` from the session state and reports the outcome.

    Args:
      ctx: The invocation context whose session state is read and updated.

    Yields:
      A single event: a retry notice when there is no code, a `CODE ERROR`
      message when execution failed, or the printed output with
      `escalate=True` when execution succeeded.
    """
    state = ctx.session.state

    # 1. Fetch the code generated by the PLLM agent. The value is checked
    # rather than the key's presence: the key is reset to None after a
    # successful run, so it can exist without holding any code.
    p_llm_code = state.get("p_llm_code")
    if p_llm_code is None:
      # The PLLM agent did not generate any code. Record the error and return
      # without escalating so that the code can be regenerated.
      state.update(
          dict(
              eval_result=(
                  "ERROR: You did not generate any code. Please try again."
              )
          )
      )
      yield Event(
          author=self.name,
          content=types.Content(
              role=self.name,
              parts=[
                  types.Part(
                      text=(
                          "The Privileged LLM did not generate any code;"
                          " attempting to regenerate."
                      )
                  )
              ],
          ),
      )
      return

    # 2. Run the code using the CaMeL interpreter.
    function_calls = state.get("function_calls") or []
    dependencies = state.get("dependencies") or ()

    # execute_code returns:
    #   (printed_output, ad_tool_calls, error, namespace, dependencies)
    printed_output, ad_tool_calls, error, _, dependencies = (
        self.camel_interpreter_service.execute_code(
            p_llm_code, function_calls, dependencies
        )
    )

    state.update(dict(function_calls=ad_tool_calls))
    state.update(dict(dependencies=dependencies))

    # 3. Add additional messages to the conversation based on the eval result.
    if error is not None:
      if isinstance(error, CaMeLException):
        error_reason = str(error.exception)
      else:
        error_reason = str(error)
      state.update(dict(eval_result=f"CODE ERROR: {error_reason}"))
      # No escalation here, so the enclosing loop can try again.
      yield Event(
          author=self.name,
          content=types.Content(
              role=self.name,
              parts=[types.Part(text=f"CODE ERROR: {error_reason}")],
          ),
      )
    else:
      # Get printed output from the eval result, and add it to the conversation.
      state.update(dict(eval_result=printed_output))
      # Remove the p_llm_code upon success so subsequent prompts aren't
      # impacted.
      state.update(dict(p_llm_code=None))
      # Stop the loop since the eval result is successful.
      yield Event(
          author=self.name,
          content=types.Content(
              role=self.name,
              parts=[types.Part(text=printed_output)],
          ),
          actions=EventActions(escalate=True),
      )


class CaMeLAgent(BaseAgent):
  """CaMeL agent.

  Runs a loop of two sub-agents: a Privileged LLM (`PLLM`) that writes code
  into the `p_llm_code` session-state key, and a `CaMeLInterpreter` that
  executes that code. The loop runs for at most 10 iterations.

  Attributes:
    model: The LLM model to use.
    instruction: The instruction to use.
    tools: The tools to use (py_callable, capabilities, dependencies)
    camel_interpreter_agent: The agent that executes the generated code.
    pllm_agent: The Privileged LLM agent that generates the code.
    loop_agent: The loop agent that alternates between the two sub-agents.
  """

  model: str | BaseLlm
  instruction: str
  tools: list[Tool]

  camel_interpreter_agent: CaMeLInterpreter
  pllm_agent: LlmAgent
  loop_agent: LoopAgent

  model_config = {"arbitrary_types_allowed": True}

  def __init__(
      self,
      name: str,
      model: str | BaseLlm = "gemini-2.5-pro",
      description: str = "",
      instruction: str = "",
      tools: Optional[list[Tool]] = None,
      security_policy_engine: Optional[SecurityPolicyEngine] = None,
      eval_mode: DependenciesPropagationMode = DependenciesPropagationMode.NORMAL,
  ):
    """Builds the interpreter service, the sub-agents and the loop agent.

    Args:
      name: The agent name.
      model: The model used by both the Privileged and the Quarantined LLM.
      description: The agent description.
      instruction: The instruction stored on the agent.
      tools: (py_callable, capabilities, dependencies) triples; defaults to no
        tools.
      security_policy_engine: The security policy engine passed to the
        interpreter. Defaults to a new `NoSecurityPolicyEngine` per agent.
      eval_mode: The dependencies propagation mode passed to the interpreter.
    """
    # Create the default engine per instance rather than once at function
    # definition time, so agents never share a default engine object.
    if security_policy_engine is None:
      security_policy_engine = security_policy.NoSecurityPolicyEngine()

    agent_tools = tools or []

    camel_interpreter_service = CaMelInterpreterService(
        model=model,
        tools=agent_tools,
        eval_args=interpreter.EvalArgs(
            eval_mode=eval_mode,
            security_policy_engine=security_policy_engine,
        ),
    )
    camel_interpreter_agent = CaMeLInterpreter(
        name="CaMeLInterpreter",
        camel_interpreter_service=camel_interpreter_service,
    )

    # The system prompt is generated from the functions exposed to the
    # interpreter and the class names to exclude.
    pllm_agent = LlmAgent(
        name="PLLM",
        model=model,
        instruction=prompts.generate_camel_system_prompt(
            list(
                map(
                    prompts.make_function,
                    camel_interpreter_service.get_funcs_for_pllm_prompt(),
                )
            ),
            camel_interpreter_service.get_classes_to_exclude(),
        ),
        output_key="p_llm_code",
    )

    loop_agent = LoopAgent(
        name="CaMeLLoopAgent",
        sub_agents=[pllm_agent, camel_interpreter_agent],
        max_iterations=10,
    )

    super().__init__(
        name=name,
        description=description,
        model=model,
        instruction=instruction,
        tools=agent_tools,
        camel_interpreter_agent=camel_interpreter_agent,
        pllm_agent=pllm_agent,
        loop_agent=loop_agent,
    )

  @override
  async def _run_async_impl(
      self, ctx: InvocationContext
  ) -> AsyncGenerator[Event, None]:
    """Runs the generate/execute loop and forwards its events.

    Args:
      ctx: The invocation context passed on to the loop agent.

    Yields:
      The events of the loop agent and, if a security policy denies the
      execution, a final event describing the violation.

    Raises:
      Exception: Any other exception raised by the loop is logged and
        re-raised.
    """
    try:
      # Run the loop agent. This is the main loop of the agent.
      async for e in self.loop_agent.run_async(ctx):
        yield e
    except security_policy.SecurityPolicyDeniedError as e:
      yield Event(
          author=self.name,
          content=types.Content(
              parts=[
                  types.Part(
                      text=(
                          "Execution stopped due to security policy"
                          f" violation: {e}"
                      )
                  )
              ],
          ),
      )
    except Exception as e:
      print(f"CaMeL agent failed: {e}", end="\n")
      raise
    finally:
      # Remove the PLLM code and interpreter evaluation result from the session
      # state. This also runs after a failure or a policy violation, so stale
      # values never leak into the next invocation.
      if "eval_result" in ctx.session.state:
        ctx.session.state.update(dict(eval_result=None))
      if "p_llm_code" in ctx.session.state:
        ctx.session.state.update(dict(p_llm_code=None))
